#!/usr/bin/python
# coding:utf-8

import pymongo
from bson.int64 import Int64
from bson.objectid import ObjectId


# MongoDB database helper
class MongoDB(object):
    """Thin convenience wrapper around one pymongo database.

    ``config`` must provide ``host``, ``port``, ``poolsize`` and ``dbname``;
    ``authdb`` (default ``"admin"``), ``dbuser`` and ``dbpwd`` are optional.
    Credentials are used only when both ``dbuser`` and ``dbpwd`` are present.

    The public methods keep the call forms and return values of the legacy
    pymongo helpers (``insert``/``update``/``remove``/``count``) but are
    implemented with the CRUD API that replaced them, because the legacy
    helpers (and ``Database.authenticate``) were removed in PyMongo 4.
    """

    # Positional parameter names of the legacy Collection.insert and
    # Collection.remove, used to map positional arguments onto keywords.
    _INSERT_PARAMS = ("doc_or_docs", "manipulate", "check_keys", "continue_on_error")
    _REMOVE_PARAMS = ("spec_or_id", "multi")

    def __init__(self, config):
        """Read connection settings from ``config`` and create the client."""
        self.host = config['host']
        self.port = config['port']
        self.poolsize = config['poolsize']
        self.dbname = config['dbname']
        self.authDB = "admin"
        if "authdb" in config:
            self.authDB = config["authdb"]
        # dbuser/dbpwd are only set when configured: __getConn checks for
        # their presence with hasattr().
        if "dbuser" in config:
            self.dbuser = config["dbuser"]
        if "dbpwd" in config:
            self.dbpwd = config["dbpwd"]
        self.__getConn()

    def __getConn(self, **kwargs):
        """Create the client and database handle once and return the database.

        ``kwargs`` is accepted for backward compatibility and ignored.
        """
        if hasattr(self, "mongodb"):
            return self.mongodb

        auth_options = {}
        if hasattr(self, "dbuser") and hasattr(self, "dbpwd"):
            # Database.authenticate() no longer exists in PyMongo 4, so the
            # credentials go to the client itself. With connect=False an
            # authentication failure surfaces on the first operation.
            auth_options = {
                "username": self.dbuser,
                "password": self.dbpwd,
                "authSource": self.authDB,
            }

        self.client = pymongo.MongoClient(
            host=self.host, port=self.port, maxPoolSize=self.poolsize,
            connect=False, **auth_options)
        self.mongodb = self.client.get_database(self.dbname)
        return self.mongodb

    def toObjectId(self, oid):
        """Convert ``oid`` to a bson ``ObjectId``."""
        return ObjectId(oid)

    def toInt64(self, num):
        """Convert ``num`` to a bson ``Int64`` (stored as a 64-bit integer)."""
        return Int64(num)

    def aggregate(self, tb, pipeline):
        """Run a raw aggregation ``pipeline`` on collection ``tb``; return the cursor."""
        return self.mongodb[tb].aggregate(pipeline)

    @staticmethod
    def _normalize_find_kwargs(key):
        """Translate the ``where``/``fields`` aliases accepted by find()/find1().

        ``where`` becomes ``filter`` and ``fields`` becomes ``projection``.
        A ``fields`` list is turned into a projection dict mapping each name
        to 1, except that a listed ``"_id"`` maps to 0 (i.e. is excluded).
        ``key`` is modified in place and returned.
        """
        if 'where' in key:
            key['filter'] = key.pop('where')
        if 'fields' in key:
            fields = key.pop('fields')
            if isinstance(fields, list):
                fields = {name: (0 if name == "_id" else 1) for name in fields}
            key['projection'] = fields
        return key

    @staticmethod
    def _as_update_doc(data):
        """Wrap a plain field->value dict in ``$set``; pass operator docs through.

        A document counts as an operator document when any top-level key
        starts with ``$`` (e.g. ``$inc``); a ``$`` elsewhere in a field name
        does not make it one.
        """
        if any(name.startswith('$') for name in data):
            return data
        return {'$set': data}

    @staticmethod
    def _as_delete_filter(spec_or_id):
        """Build a delete filter the way the legacy ``remove`` interpreted it.

        ``None`` matches every document; a non-dict value is taken as an ``_id``.
        """
        if spec_or_id is None:
            return {}
        if isinstance(spec_or_id, dict):
            return spec_or_id
        return {'_id': spec_or_id}

    def insert(self, tb, *arg, **key):
        """Insert one document or an iterable of documents into ``tb``.

        Accepts the legacy ``Collection.insert`` arguments (``doc_or_docs``,
        ``manipulate``, ``check_keys``, ``continue_on_error``) positionally or
        by keyword; ``manipulate`` and ``check_keys`` have no modern
        counterpart and are ignored. Other keywords are forwarded to
        ``insert_one``/``insert_many``.

        Returns the inserted ``_id`` for a single (dict) document, otherwise
        the list of inserted ``_id`` values.
        """
        if len(arg) > len(self._INSERT_PARAMS):
            raise TypeError("insert() got too many positional arguments")
        key.update(zip(self._INSERT_PARAMS, arg))
        if "doc_or_docs" not in key:
            raise TypeError("insert() missing required argument 'doc_or_docs'")
        docs = key.pop("doc_or_docs")
        key.pop("manipulate", None)
        key.pop("check_keys", None)
        continue_on_error = key.pop("continue_on_error", False)

        collection = self.mongodb[tb]
        if isinstance(docs, dict):
            return collection.insert_one(docs, **key).inserted_id
        # continue_on_error=True meant "keep going after a failed document",
        # which is what an unordered bulk insert does.
        return collection.insert_many(
            docs, ordered=not continue_on_error, **key).inserted_ids

    def insert_many(self, tb, *arg, **key):
        """Insert many documents into ``tb``; return pymongo's InsertManyResult.

        Document validation is bypassed by default; an explicit
        ``bypass_document_validation`` from the caller takes precedence.
        """
        key.setdefault("bypass_document_validation", True)
        return self.mongodb[tb].insert_many(*arg, **key)

    def update(self, tb, where, data, **key):
        """Update the documents of ``tb`` matching ``where`` with ``data``.

        ``data`` without a top-level ``$`` operator is wrapped in ``$set``.
        Every matching document is updated unless ``multi=False`` is passed,
        in which case only the first one is. Legacy-only ``manipulate`` and
        ``check_keys`` are ignored; other keywords (e.g. ``upsert``) are
        forwarded. Returns the raw server result document.
        """
        multi = key.pop("multi", True)
        key.pop("manipulate", None)
        key.pop("check_keys", None)
        collection = self.mongodb[tb]
        method = collection.update_many if multi else collection.update_one
        return method(where, self._as_update_doc(data), **key).raw_result

    def find1(self, tb, *arg, **key):
        """Return the first document of ``tb`` matching the query, or None.

        Supports the ``where``/``fields`` aliases of _normalize_find_kwargs.
        """
        return self.mongodb[tb].find_one(*arg, **self._normalize_find_kwargs(key))

    def find(self, tb, *arg, **key):
        """Return all documents of ``tb`` matching the query as a list.

        Supports the ``where``/``fields`` aliases of _normalize_find_kwargs.
        """
        return list(self.mongodb[tb].find(*arg, **self._normalize_find_kwargs(key)))

    def delete(self, tb, *arg, **key):
        """Delete documents of ``tb`` using the legacy ``remove`` semantics.

        The filter (``spec_or_id``, alias ``where``) may be a query dict, a
        bare ``_id`` value, or None -- None deletes EVERY document. All
        matches are deleted unless ``multi=False``. Other keywords are
        forwarded. Returns the raw server result document.
        """
        if len(arg) > len(self._REMOVE_PARAMS):
            raise TypeError("delete() got too many positional arguments")
        key.update(zip(self._REMOVE_PARAMS, arg))
        if 'where' in key:
            key['spec_or_id'] = key.pop('where')
        spec = self._as_delete_filter(key.pop('spec_or_id', None))
        multi = key.pop('multi', True)
        collection = self.mongodb[tb]
        method = collection.delete_many if multi else collection.delete_one
        return method(spec, **key).raw_result

    def count(self, tb, where=None, **kwargs):
        """Return the number of documents of ``tb`` matching ``where``.

        Without a filter or options, ``estimated_document_count`` (collection
        metadata) is used; otherwise ``count_documents``, which rejects the
        ``$where``/``$near``/``$nearSphere`` query operators.
        """
        collection = self.mongodb[tb]
        if not where and not kwargs:
            return collection.estimated_document_count()
        return collection.count_documents(where or {}, **kwargs)


if __name__ == "__main__":
    # This module only defines the MongoDB helper class; executing it
    # directly does nothing.
    pass
